{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "BojqmYOsPk0A"
},
"source": [
"##### Copyright 2024 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "FtMmJ-pvPfNl"
},
"outputs": [],
"source": [
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IyATsAVlPz1W"
},
"source": [
"# Inference with Gemma 2 using mistral.rs\n",
"\n",
"[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open-source language models from Google. Built from the same research and technology used to create the Gemini models, Gemma models are text-to-text, decoder-only large language models (LLMs), available in English, with open weights, pre-trained variants, and instruction-tuned variants.\n",
"Gemma models are well-suited for various text-generation tasks, including question-answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n",
"\n",
"[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a versatile framework for Large Language Model (LLM) inference supporting text-to-text and multimodal LLMs. It offers features like grammar support for structured outputs, and inference using LoRA-fine-tuned models, making it a powerful tool for a wide range of AI applications.\n",
"\n",
"In this notebook, you will learn how to prompt the Gemma 2 model in various ways using the **mistral.rs** Python APIs in a Google Colab environment.\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Using_with_mistral_rs.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nFgaq_--Qg-O"
},
"source": [
"## Setup\n",
"\n",
"### Select the Colab runtime\n",
"To complete this tutorial, you can use the CPU runtime in Colab. Since the default accelerator for any Colab runtime is CPU, simply click the **Connect** button in the top-right corner of the Colab window.\n",
"\n",
"### Setup Hugging Face\n",
"\n",
"**Before you dive into the tutorial, let's get you set up with Hugging face:**\n",
"\n",
"1. **Hugging Face Account:** If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
"\n",
"2. **Hugging Face Token:** Generate a Hugging Face access (preferably `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
"\n",
"**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bLEUJYZ8QmGz"
},
"source": [
"### Configure your HF token\n",
"Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
"\n",
"1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
"2. Create a new secret with the name `HF_TOKEN`.\n",
"3. Copy/paste your HF token key into the Value input box of `HF_TOKEN`.\n",
"4. Toggle the button on the left to allow notebook access to the secret."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hK-qUiQGQbe5"
},
"outputs": [],
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
"# vars as appropriate for your system.\n",
"os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5DrB-trDQsgE"
},
"source": [
"### Install dependencies\n",
"\n",
"First, you must install the Python package for mistral.rs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LLjxxhk2Qrf_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting mistralrs==0.3.2\n",
" Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl.metadata (1.7 kB)\n",
"Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl (14.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.4/14.4 MB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: mistralrs\n",
"Successfully installed mistralrs-0.3.2\n"
]
}
],
"source": [
"!pip install mistralrs==0.3.2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_wsT7Eb_1esZ"
},
"source": [
"## Overview\n",
"\n",
" In this notebook, you will explore how to prompt the Gemma 2 model using the Python APIs of mistral.rs. It's divided into the following sections:\n",
"\n",
"1. Load the Gemma 2 model\n",
"2. Response generation\n",
"3. Non-streaming chat completion\n",
"4. Streaming chat completion\n",
"5. Inference with grammar\n",
"\n",
"Additionally, you can run inference on Gemma 2 using the Rust APIs and command-line interface of mistral.rs. Read more about them in the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#get-started-fast-).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tMfQBPRYsknZ"
},
"source": [
"## Load the Gemma 2 model\n",
"\n",
"In this section, you will learn how to load the Gemma 2 model from Hugging Face Hub using the mistral.rs Python APIs.\n",
"\n",
"First, create an instance of the `Which` class specifying the model to load. Set `model_id` to the Hugging Face Hub repository ID of the desired Gemma 2 model variant. Specify `Architecture.Gemma2` for the arch parameter.\n",
"\n",
"Create an instance of the `Runner` class by providing the created `Which` instance. Set `token_source` to `env:HF_TOKEN` to indicate downloading the model from Hugging Face Hub using your existing Hugging Face token.\n",
"\n",
"Models can also be loaded from the local file system. Refer to the [mistral.rs get models guide](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#getting-models) for more details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VYFDjCBBtBhS"
},
"outputs": [],
"source": [
"from mistralrs import Runner, Which, Architecture\n",
"\n",
"runner = Runner(\n",
" which=Which.Plain(\n",
" model_id=\"google/gemma-2-2b-it\",\n",
" arch=Architecture.Gemma2,\n",
" ),\n",
" token_source=\"env:HF_TOKEN\",\n",
")"
]
},
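{
"cell_type": "markdown",
"metadata": {
"id": "aLocalLoadMd01"
},
"source": [
"As mentioned above, models can also be loaded from the local file system instead of being downloaded from the Hugging Face Hub. The sketch below assumes, following the mistral.rs get-models guide, that `model_id` can point to a local directory containing the model weights, tokenizer, and configuration files. The path shown is only a placeholder; replace it with your own model directory before running the cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aLocalLoadCode1"
},
"outputs": [],
"source": [
"# Minimal sketch: load Gemma 2 from a local directory instead of Hugging Face Hub.\n",
"# \"/path/to/gemma-2-2b-it\" is a placeholder -- point it at a directory that\n",
"# contains the downloaded model files before running this cell.\n",
"local_runner = Runner(\n",
"    which=Which.Plain(\n",
"        model_id=\"/path/to/gemma-2-2b-it\",\n",
"        arch=Architecture.Gemma2,\n",
"    ),\n",
")"
]
},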
{
"cell_type": "markdown",
"metadata": {
"id": "YoXapKI3vE4m"
},
"source": [
"## Response generation\n",
"\n",
"Next, you'll send a prompt to Gemma 2 for inference using mistral.rs.\n",
"\n",
"To achieve this, create an instance of the `CompletionRequest` class, specifying your desired prompt in the `prompt` parameter. Set the model parameter to \"gemma2\". mistral.rs allows you to configure common LLM parameters like `top_p`, `max_tokens`, and `temperature` within `CompletionRequest`. You can find the full list in the [class definition of `CompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L44).\n",
"\n",
"Finally, call the `send_completion_request` method on the `runner` object, passing the newly created `CompletionRequest` instance as the `request` argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "niIzVlPbeSvO"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A black hole is a region of spacetime where gravity is so strong that nothing, not even light, can escape. \n",
"\n",
"Here's a breakdown:\n",
"\n",
"**What causes a black hole?**\n",
"\n",
"* **Stellar Collapse:** When a massive star runs out of fuel, it collapses under its own gravity. This collapse can create a black hole if the star's mass is enough.\n",
"* **Supermassive Black Holes:** These are found at the centers of most galaxies, including our own Milky Way. Their formation is still a mystery, but they likely formed from the merging of smaller black holes or from the collapse of massive gas clouds.\n",
"\n",
"**Key Features of a Black Hole:**\n",
"\n",
"* **Event Horizon:** This is the boundary around a black hole beyond which escape is impossible. Once something crosses the event horizon, it's trapped forever.\n",
"* **Singularity:** This is the theoretical point at the center of a black hole where all the matter is compressed into an infinitely small point.\n",
"* **Gravitational Pull:** Black holes have immense gravitational pull, which warps spacetime around them.\n",
"\n",
"**How do we know black holes exist?**\n",
"\n",
"* **Gravitational Effects:** We can detect black holes by observing their gravitational effects on nearby stars and gas.\n",
"* **X-ray Emissions:** As matter falls into a black hole, it heats up and emits X-rays, which can be detected by telescopes.\n",
"* **Gravitational Waves:** When black holes collide, they create ripples in spacetime called gravitational waves, which can be detected by specialized instruments.\n",
"\n",
"**Interesting Facts:**\n",
"\n",
"* Black holes are not actually \"holes\" in space, but rather regions of extreme density and gravity.\n",
"* The size of a black hole is measured by its \"Schwarzschild radius,\" which is the distance from the center at which the escape velocity equals the speed of light.\n",
"* Black holes are still a subject of active research, and scientists are constantly learning more about them.\n",
"\n",
"\n",
"Let me know if you have any other questions! \n",
"\n",
"Usage {\n",
" completion_tokens: 415,\n",
" prompt_tokens: 7,\n",
" total_tokens: 422,\n",
" avg_tok_per_sec: 1.3271735,\n",
" avg_prompt_tok_per_sec: 1.4326648,\n",
" avg_compl_tok_per_sec: 1.3255271,\n",
" total_time_sec: 317.969,\n",
" total_prompt_time_sec: 4.886,\n",
" total_completion_time_sec: 313.083,\n",
"}\n"
]
}
],
"source": [
"from mistralrs import CompletionRequest\n",
"\n",
"request = CompletionRequest(\n",
" model=\"gemma2\",\n",
" prompt=\"What is a black hole?\",\n",
" max_tokens=512,\n",
" top_p=0.1,\n",
" temperature=0.1,\n",
")\n",
"\n",
"response = runner.send_completion_request(request)\n",
"\n",
"print(response.choices[0].text)\n",
"print(response.usage)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vu1jh_nLMXyS"
},
"source": [
"## Non-streaming chat completion\n",
"\n",
"In addition to basic prompting, mistral.rs also supports multi-turn conversations with LLMs.\n",
"\n",
"For multi-turn conversations, you'll maintain a list of dictionaries, each representing a single message in the conversation history with Gemma 2. These dictionaries should include a key named `role` specifying whether the user or the assistant (model) generated the message.\n",
"\n",
"Create an instance of `ChatCompletionRequest` by passing the conversation history to the `messages` parameter. Then, invoke the `send_chat_completion_request` method of the `runner` with the newly created `ChatCompletionRequest` instance as the `request` argument.\n",
"\n",
"You can also configure any other model parameters you want to use. Refer to the the [class definition of `ChatCompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L11)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ffjhPyHpvEJM"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. \n",
"\n",
"Usage {\n",
" completion_tokens: 29,\n",
" prompt_tokens: 74,\n",
" total_tokens: 103,\n",
" avg_tok_per_sec: 2.808376,\n",
" avg_prompt_tok_per_sec: 4.5204644,\n",
" avg_compl_tok_per_sec: 1.4281493,\n",
" total_time_sec: 36.676,\n",
" total_prompt_time_sec: 16.37,\n",
" total_completion_time_sec: 20.306,\n",
"}\n"
]
}
],
"source": [
"from mistralrs import ChatCompletionRequest\n",
"\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Who are the first humans to land on moon?\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"The first humans to land on the Moon were **Neil Armstrong and Buzz Aldrin** of the Apollo 11 mission on **July 20, 1969**.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Which country did they belong to?\",\n",
" },\n",
"]\n",
"\n",
"request = ChatCompletionRequest(\n",
" model=\"gemma2\",\n",
" messages=messages,\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
" temperature=0.1,\n",
")\n",
"\n",
"response = runner.send_chat_completion_request(request)\n",
"\n",
"print(response.choices[0].message.content)\n",
"print(response.usage)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7W-F2n_xhWRl"
},
"source": [
"Notice how the model generated the response for the last query from the user based on the previous conversation history."
]
},
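{
"cell_type": "markdown",
"metadata": {
"id": "bFollowUpMd001"
},
"source": [
"To carry the conversation further, append the assistant's reply and your next question to the history, then send another `ChatCompletionRequest`. The sketch below builds a separate `followup_messages` list so that the original `messages` list stays unchanged for the streaming example in the next section; the follow-up question itself is just an illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bFollowUpCode01"
},
"outputs": [],
"source": [
"# Extend the conversation history with the assistant's reply and a new user turn.\n",
"# A copy is used so the original `messages` list is left intact.\n",
"followup_messages = messages + [\n",
"    {\n",
"        \"role\": \"assistant\",\n",
"        \"content\": response.choices[0].message.content,\n",
"    },\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": \"What was the name of their lunar module?\",\n",
"    },\n",
"]\n",
"\n",
"request = ChatCompletionRequest(\n",
"    model=\"gemma2\",\n",
"    messages=followup_messages,\n",
"    max_tokens=256,\n",
"    top_p=0.1,\n",
"    temperature=0.1,\n",
")\n",
"\n",
"followup_response = runner.send_chat_completion_request(request)\n",
"\n",
"print(followup_response.choices[0].message.content)"
]
},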
{
"cell_type": "markdown",
"metadata": {
"id": "qZehF79UOVWg"
},
"source": [
"## Streaming chat completion\n",
"\n",
"mistral.rs also supports obtaining streamed responses from the model during multi-turn conversations. To enable streaming, set the `stream` parameter of `ChatCompletionRequest` to `True`.\n",
"\n",
"The `send_chat_completion_request` function will return an iterable object. You can access each chunk of the streamed response by iterating over this object.\n",
"\n",
"Pass the conversation history (`messages`) you created earlier to the `ChatCompletionRequest` instance with `stream=True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nxNg6cXFog29"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. \n"
]
}
],
"source": [
"request = ChatCompletionRequest(\n",
" model=\"gemma2\",\n",
" messages=messages,\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
" temperature=0.1,\n",
" stream=True,\n",
")\n",
"\n",
"response = runner.send_chat_completion_request(request)\n",
"\n",
"for chunk in response:\n",
" print(chunk.choices[0].delta.content, end=\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3RALnmwQP6WQ"
},
"source": [
"## Inference with grammar\n",
"\n",
"Specifying a grammar when creating a `ChatCompletionRequest` or `CompletionRequest` allows you to constrain the model's output to a specific format.\n",
"\n",
"To ensure the model generates responses that match a regular expression, set the parameter `grammar_type` to \"regex\" and `grammar` to the desired regular expression string.\n",
"\n",
"The following code snippet illustrates how to restrict the model's response to two-digit numbers using regular expression grammar:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gi2FGoXJv-K9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13\n",
"Usage {\n",
" completion_tokens: 5,\n",
" prompt_tokens: 31,\n",
" total_tokens: 36,\n",
" avg_tok_per_sec: 3.115804,\n",
" avg_prompt_tok_per_sec: 3.5123498,\n",
" avg_compl_tok_per_sec: 1.8328446,\n",
" total_time_sec: 11.554,\n",
" total_prompt_time_sec: 8.826,\n",
" total_completion_time_sec: 2.728,\n",
"}\n"
]
}
],
"source": [
"request = CompletionRequest(\n",
" model=\"gemma2\",\n",
" prompt=\"What is the next number in the fibonnaci series: 1, 1, 2, 3, 5, 8 ?\",\n",
" grammar_type=\"regex\",\n",
" grammar=r\"\\d{2}\",\n",
" top_p=0.1,\n",
" temperature=0.1,\n",
")\n",
"\n",
"response = runner.send_completion_request(request)\n",
"\n",
"print(response.choices[0].text)\n",
"print(response.usage)"
]
},
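{
"cell_type": "markdown",
"metadata": {
"id": "cRegexChatMd001"
},
"source": [
"Grammars can also be combined with chat completions. As a further illustration (the prompt and pattern here are arbitrary choices), the sketch below uses a regular expression to constrain the model's answer to either `yes` or `no`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cRegexChatCode1"
},
"outputs": [],
"source": [
"# Constrain a chat response to a bare yes/no answer using a regex grammar.\n",
"request = ChatCompletionRequest(\n",
"    model=\"gemma2\",\n",
"    messages=[\n",
"        {\"role\": \"user\", \"content\": \"Is the Sun a star? Answer yes or no.\"},\n",
"    ],\n",
"    grammar_type=\"regex\",\n",
"    grammar=r\"(yes|no)\",\n",
"    max_tokens=8,\n",
"    top_p=0.1,\n",
"    temperature=0.1,\n",
")\n",
"\n",
"response = runner.send_chat_completion_request(request)\n",
"\n",
"print(response.choices[0].message.content)"
]
},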
{
"cell_type": "markdown",
"metadata": {
"id": "X72NYX92nv40"
},
"source": [
"mistral.rs offers support for various grammar types beyond regular expressions. Refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#description) for more details."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vlRAC8OFa1pQ"
},
"source": [
"These are just a few examples of how you can perform inference with Gemma 2 using mistral.rs. To explore its capabilities further, you can refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#--mistralrs).\n"
]
}
],
"metadata": {
"colab": {
"name": "[Gemma_2]Using_with_mistral_rs.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}